{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UEBilEjLj5wY"
   },
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 119
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 536,
     "status": "ok",
     "timestamp": 1524974472601,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "GOzuY8Yvj5wb",
    "outputId": "c19362ce-f87a-4cc2-84cc-8d7b4b9e6007"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.6.8\n",
      "IPython 7.2.0\n",
      "\n",
      "torch 1.0.0\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rH4XmErYj5wm"
   },
   "source": [
    "# Model Zoo -- ResNet-50 MNIST Digits Classifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Network Architecture"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The network in this notebook is an implementation of the ResNet-50 [1] architecture, trained on the MNIST dataset as a handwritten digit classifier.  \n",
    "\n",
    "\n",
    "References\n",
    "    \n",
    "- [1] He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 770-778). ([CVPR Link](https://www.cv-foundation.org/openaccess/content_cvpr_2016/html/He_Deep_Residual_Learning_CVPR_2016_paper.html))\n",
    "\n",
    "The ResNet-50 architecture is similar to the ResNet-34 architecture shown below (from [1]):\n",
    "\n",
    "\n",
    "![](../images/resnets/resnet34/resnet34-arch.png)\n",
    "\n",
    "However, ResNet-50 replaces the two-layer residual blocks with three-layer bottleneck blocks (from [1]):\n",
    "\n",
    "\n",
    "![](../images/resnets/resnet50/resnet50-arch-1.png)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following figure illustrates residual blocks with skip connections such that the input passed via the shortcut matches the dimensions of the main path's output, which allows the network to learn identity functions.\n",
    "\n",
    "![](../images/resnets/resnet-ex-1-1.png)\n",
    "\n",
    "\n",
    "The ResNet-34 architecture actually uses residual blocks with skip connections in which the input passed via the shortcut is resized to match the dimensions of the main path's output. Such a residual block is illustrated below:\n",
    "\n",
    "![](../images/resnets/resnet-ex-1-2.png)\n",
    "\n",
    "ResNet-50 uses a bottleneck block, as shown below:\n",
    "\n",
    "![](../images/resnets/resnet-ex-1-3.png)"
   ]
  },
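  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a compact summary of the block types illustrated above, the residual computations can be sketched as follows. The symbols are assumptions introduced here for illustration and do not appear in the figures: $\\mathbf{x}$ is the block input, $\\mathcal{F}$ is the main (convolutional) path, $W_s$ is a 1x1 convolution on the shortcut, and $\\sigma$ is the ReLU activation. Batch normalization is omitted for brevity.\n",
    "\n",
    "$$\\text{identity shortcut:} \\quad \\mathbf{y} = \\sigma\\big(\\mathcal{F}(\\mathbf{x}) + \\mathbf{x}\\big)$$\n",
    "\n",
    "$$\\text{resized shortcut:} \\quad \\mathbf{y} = \\sigma\\big(\\mathcal{F}(\\mathbf{x}) + W_s\\,\\mathbf{x}\\big)$$\n",
    "\n",
    "In the bottleneck variant, $\\mathcal{F}$ is a stack of 1x1, 3x3, and 1x1 convolutions: the first 1x1 convolution reduces the number of channels, the 3x3 convolution operates on the reduced representation, and the final 1x1 convolution expands the channels again (by a factor of 4 in [1]). If $\\mathcal{F}(\\mathbf{x}) = \\mathbf{0}$, the first form reduces to $\\mathbf{y} = \\sigma(\\mathbf{x})$, which equals $\\mathbf{x}$ whenever $\\mathbf{x}$ is itself the output of a ReLU. This is why such blocks can easily represent an identity mapping."
   ]
  },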
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a more detailed explanation, see the other notebook, [resnet-ex-1.ipynb](resnet-ex-1.ipynb)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "MkoGLH_Tj5wn"
   },
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     }
    },
    "colab_type": "code",
    "id": "ORj09gnrj5wp"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "I6hghKPxj5w0"
   },
   "source": [
    "## Model Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 85
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 23936,
     "status": "ok",
     "timestamp": 1524974497505,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "NnT0sZIwj5wu",
    "outputId": "55aed925-d17e-4c6a-8c71-0d9b3bde5637"
   },
   "outputs": [],
   "source": [
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "# Hyperparameters\n",
    "RANDOM_SEED = 1\n",
    "LEARNING_RATE = 0.0001\n",
    "BATCH_SIZE = 128\n",
    "NUM_EPOCHS = 20\n",
    "\n",
    "# Architecture\n",
    "NUM_FEATURES = 28*28\n",
    "NUM_CLASSES = 10\n",
    "\n",
    "# Other\n",
    "DEVICE = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "GRAYSCALE = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### MNIST Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Image batch dimensions: torch.Size([128, 1, 28, 28])\n",
      "Image label dimensions: torch.Size([128])\n"
     ]
    }
   ],
   "source": [
    "##########################\n",
    "### MNIST DATASET\n",
    "##########################\n",
    "\n",
    "# Note transforms.ToTensor() scales input images\n",
    "# to 0-1 range\n",
    "train_dataset = datasets.MNIST(root='data', \n",
    "                               train=True, \n",
    "                               transform=transforms.ToTensor(),\n",
    "                               download=True)\n",
    "\n",
    "test_dataset = datasets.MNIST(root='data', \n",
    "                              train=False, \n",
    "                              transform=transforms.ToTensor())\n",
    "\n",
    "\n",
    "train_loader = DataLoader(dataset=train_dataset, \n",
    "                          batch_size=BATCH_SIZE, \n",
    "                          shuffle=True)\n",
    "\n",
    "test_loader = DataLoader(dataset=test_dataset, \n",
    "                         batch_size=BATCH_SIZE, \n",
    "                         shuffle=False)\n",
    "\n",
    "# Checking the dataset\n",
    "for images, labels in train_loader:  \n",
    "    print('Image batch dimensions:', images.shape)\n",
    "    print('Image label dimensions:', labels.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1 | Batch index: 0 | Batch size: 128\n",
      "Epoch: 2 | Batch index: 0 | Batch size: 128\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(DEVICE)\n",
    "torch.manual_seed(0)\n",
    "\n",
    "for epoch in range(2):\n",
    "\n",
    "    for batch_idx, (x, y) in enumerate(train_loader):\n",
    "        \n",
    "        print('Epoch:', epoch+1, end='')\n",
    "        print(' | Batch index:', batch_idx, end='')\n",
    "        print(' | Batch size:', y.size()[0])\n",
    "        \n",
    "        x = x.to(device)\n",
    "        y = y.to(device)\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following code cell, which implements the ResNet-50 architecture, is a derivative of the code provided at https://pytorch.org/docs/0.4.0/_modules/torchvision/models/resnet.html."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "### MODEL\n",
    "##########################\n",
    "\n",
    "\n",
    "def conv3x3(in_planes, out_planes, stride=1):\n",
    "    \"\"\"3x3 convolution with padding\"\"\"\n",
    "    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n",
    "                     padding=1, bias=False)\n",
    "\n",
    "\n",
    "class Bottleneck(nn.Module):\n",
    "    expansion = 4\n",
    "\n",
    "    def __init__(self, inplanes, planes, stride=1, downsample=None):\n",
    "        super(Bottleneck, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(planes)\n",
    "        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,\n",
    "                               padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(planes)\n",
    "        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)\n",
    "        self.bn3 = nn.BatchNorm2d(planes * 4)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.downsample = downsample\n",
    "        self.stride = stride\n",
    "\n",
    "    def forward(self, x):\n",
    "        residual = x\n",
    "\n",
    "        out = self.conv1(x)\n",
    "        out = self.bn1(out)\n",
    "        out = self.relu(out)\n",
    "\n",
    "        out = self.conv2(out)\n",
    "        out = self.bn2(out)\n",
    "        out = self.relu(out)\n",
    "\n",
    "        out = self.conv3(out)\n",
    "        out = self.bn3(out)\n",
    "\n",
    "        if self.downsample is not None:\n",
    "            residual = self.downsample(x)\n",
    "\n",
    "        out += residual\n",
    "        out = self.relu(out)\n",
    "\n",
    "        return out\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "class ResNet(nn.Module):\n",
    "\n",
    "    def __init__(self, block, layers, num_classes, grayscale):\n",
    "        self.inplanes = 64\n",
    "        if grayscale:\n",
    "            in_dim = 1\n",
    "        else:\n",
    "            in_dim = 3\n",
    "        super(ResNet, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(in_dim, 64, kernel_size=7, stride=2, padding=3,\n",
    "                               bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(64)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
    "        self.layer1 = self._make_layer(block, 64, layers[0])\n",
    "        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n",
    "        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n",
    "        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n",
    "        self.avgpool = nn.AvgPool2d(7, stride=1)\n",
    "        self.fc = nn.Linear(512 * block.expansion, num_classes)\n",
    "\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, nn.Conv2d):\n",
    "                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n",
    "                m.weight.data.normal_(0, (2. / n)**.5)\n",
    "            elif isinstance(m, nn.BatchNorm2d):\n",
    "                m.weight.data.fill_(1)\n",
    "                m.bias.data.zero_()\n",
    "\n",
    "    def _make_layer(self, block, planes, blocks, stride=1):\n",
    "        downsample = None\n",
    "        if stride != 1 or self.inplanes != planes * block.expansion:\n",
    "            downsample = nn.Sequential(\n",
    "                nn.Conv2d(self.inplanes, planes * block.expansion,\n",
    "                          kernel_size=1, stride=stride, bias=False),\n",
    "                nn.BatchNorm2d(planes * block.expansion),\n",
    "            )\n",
    "\n",
    "        layers = []\n",
    "        layers.append(block(self.inplanes, planes, stride, downsample))\n",
    "        self.inplanes = planes * block.expansion\n",
    "        for i in range(1, blocks):\n",
    "            layers.append(block(self.inplanes, planes))\n",
    "\n",
    "        return nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = self.bn1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.maxpool(x)\n",
    "\n",
    "        x = self.layer1(x)\n",
    "        x = self.layer2(x)\n",
    "        x = self.layer3(x)\n",
    "        x = self.layer4(x)\n",
    "        # because MNIST is already 1x1 here:\n",
    "        # disable avg pooling\n",
    "        #x = self.avgpool(x)\n",
    "        \n",
    "        x = x.view(x.size(0), -1)\n",
    "        logits = self.fc(x)\n",
    "        probas = F.softmax(logits, dim=1)\n",
    "        return logits, probas\n",
    "\n",
    "\n",
    "\n",
    "def resnet50(num_classes):\n",
    "    \"\"\"Constructs a ResNet-50 model.\"\"\"\n",
    "    model = ResNet(block=Bottleneck, \n",
    "                   layers=[3, 4, 6, 3],\n",
    "                   num_classes=num_classes,\n",
    "                   grayscale=GRAYSCALE)\n",
    "    return model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     }
    },
    "colab_type": "code",
    "id": "_lza9t_uj5w1"
   },
   "outputs": [],
   "source": [
    "torch.manual_seed(RANDOM_SEED)\n",
    "\n",
    "model = resnet50(NUM_CLASSES)\n",
    "model.to(DEVICE)\n",
    " \n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)  "
   ]
  },
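  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The comment in `ResNet.forward` above states that the feature maps are already 1x1 for MNIST before the (disabled) average pooling. This can be checked with the standard output-size formula for convolution and pooling layers. The symbols are assumptions introduced here for illustration: $n$ is the spatial size, $k$ the kernel size, $s$ the stride, and $p$ the padding.\n",
    "\n",
    "$$n_{\\text{out}} = \\left\\lfloor \\frac{n_{\\text{in}} + 2p - k}{s} \\right\\rfloor + 1$$\n",
    "\n",
    "Applied to 28x28 MNIST inputs:\n",
    "\n",
    "- `conv1` ($k=7$, $s=2$, $p=3$): $\\lfloor (28 + 6 - 7)/2 \\rfloor + 1 = 14$\n",
    "- `maxpool` ($k=3$, $s=2$, $p=1$): $\\lfloor (14 + 2 - 3)/2 \\rfloor + 1 = 7$\n",
    "- `layer1` (stride 1): $7$\n",
    "- `layer2` (3x3 convolution with $s=2$, $p=1$): $\\lfloor (7 + 2 - 3)/2 \\rfloor + 1 = 4$\n",
    "- `layer3` (3x3 convolution with $s=2$, $p=1$): $\\lfloor (4 + 2 - 3)/2 \\rfloor + 1 = 2$\n",
    "- `layer4` (3x3 convolution with $s=2$, $p=1$): $\\lfloor (2 + 2 - 3)/2 \\rfloor + 1 = 1$\n",
    "\n",
    "The final feature map is therefore 1x1 with $512 \\cdot 4 = 2048$ channels, which matches the input size of the fully connected layer, `512 * block.expansion`. A 7x7 average pooling window would not fit a 1x1 input, which is consistent with the pooling call being commented out."
   ]
  },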
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "RAodboScj5w6"
   },
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 1547
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 2384585,
     "status": "ok",
     "timestamp": 1524976888520,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "Dzh3ROmRj5w7",
    "outputId": "5f8fd8c9-b076-403a-b0b7-fd2d498b48d7",
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/020 | Batch 0000/0469 | Cost: 2.4814\n",
      "Epoch: 001/020 | Batch 0050/0469 | Cost: 1.6407\n",
      "Epoch: 001/020 | Batch 0100/0469 | Cost: 1.2549\n",
      "Epoch: 001/020 | Batch 0150/0469 | Cost: 0.8013\n",
      "Epoch: 001/020 | Batch 0200/0469 | Cost: 0.7644\n",
      "Epoch: 001/020 | Batch 0250/0469 | Cost: 0.6729\n",
      "Epoch: 001/020 | Batch 0300/0469 | Cost: 0.5566\n",
      "Epoch: 001/020 | Batch 0350/0469 | Cost: 0.4682\n",
      "Epoch: 001/020 | Batch 0400/0469 | Cost: 0.3663\n",
      "Epoch: 001/020 | Batch 0450/0469 | Cost: 0.3904\n",
      "Epoch: 001/020 | Train: 90.932%\n",
      "Time elapsed: 0.66 min\n",
      "Epoch: 002/020 | Batch 0000/0469 | Cost: 0.3419\n",
      "Epoch: 002/020 | Batch 0050/0469 | Cost: 0.2901\n",
      "Epoch: 002/020 | Batch 0100/0469 | Cost: 0.1937\n",
      "Epoch: 002/020 | Batch 0150/0469 | Cost: 0.2761\n",
      "Epoch: 002/020 | Batch 0200/0469 | Cost: 0.2688\n",
      "Epoch: 002/020 | Batch 0250/0469 | Cost: 0.1875\n",
      "Epoch: 002/020 | Batch 0300/0469 | Cost: 0.2401\n",
      "Epoch: 002/020 | Batch 0350/0469 | Cost: 0.1768\n",
      "Epoch: 002/020 | Batch 0400/0469 | Cost: 0.2161\n",
      "Epoch: 002/020 | Batch 0450/0469 | Cost: 0.1624\n",
      "Epoch: 002/020 | Train: 96.738%\n",
      "Time elapsed: 1.34 min\n",
      "Epoch: 003/020 | Batch 0000/0469 | Cost: 0.1413\n",
      "Epoch: 003/020 | Batch 0050/0469 | Cost: 0.0832\n",
      "Epoch: 003/020 | Batch 0100/0469 | Cost: 0.0924\n",
      "Epoch: 003/020 | Batch 0150/0469 | Cost: 0.0587\n",
      "Epoch: 003/020 | Batch 0200/0469 | Cost: 0.0991\n",
      "Epoch: 003/020 | Batch 0250/0469 | Cost: 0.1508\n",
      "Epoch: 003/020 | Batch 0300/0469 | Cost: 0.1367\n",
      "Epoch: 003/020 | Batch 0350/0469 | Cost: 0.1431\n",
      "Epoch: 003/020 | Batch 0400/0469 | Cost: 0.1650\n",
      "Epoch: 003/020 | Batch 0450/0469 | Cost: 0.1842\n",
      "Epoch: 003/020 | Train: 98.288%\n",
      "Time elapsed: 2.03 min\n",
      "Epoch: 004/020 | Batch 0000/0469 | Cost: 0.0812\n",
      "Epoch: 004/020 | Batch 0050/0469 | Cost: 0.0499\n",
      "Epoch: 004/020 | Batch 0100/0469 | Cost: 0.0413\n",
      "Epoch: 004/020 | Batch 0150/0469 | Cost: 0.0929\n",
      "Epoch: 004/020 | Batch 0200/0469 | Cost: 0.0501\n",
      "Epoch: 004/020 | Batch 0250/0469 | Cost: 0.1147\n",
      "Epoch: 004/020 | Batch 0300/0469 | Cost: 0.0277\n",
      "Epoch: 004/020 | Batch 0350/0469 | Cost: 0.0659\n",
      "Epoch: 004/020 | Batch 0400/0469 | Cost: 0.0854\n",
      "Epoch: 004/020 | Batch 0450/0469 | Cost: 0.0368\n",
      "Epoch: 004/020 | Train: 98.942%\n",
      "Time elapsed: 2.71 min\n",
      "Epoch: 005/020 | Batch 0000/0469 | Cost: 0.0120\n",
      "Epoch: 005/020 | Batch 0050/0469 | Cost: 0.0127\n",
      "Epoch: 005/020 | Batch 0100/0469 | Cost: 0.0516\n",
      "Epoch: 005/020 | Batch 0150/0469 | Cost: 0.0341\n",
      "Epoch: 005/020 | Batch 0200/0469 | Cost: 0.0600\n",
      "Epoch: 005/020 | Batch 0250/0469 | Cost: 0.1150\n",
      "Epoch: 005/020 | Batch 0300/0469 | Cost: 0.0312\n",
      "Epoch: 005/020 | Batch 0350/0469 | Cost: 0.0494\n",
      "Epoch: 005/020 | Batch 0400/0469 | Cost: 0.0711\n",
      "Epoch: 005/020 | Batch 0450/0469 | Cost: 0.0531\n",
      "Epoch: 005/020 | Train: 99.060%\n",
      "Time elapsed: 3.39 min\n",
      "Epoch: 006/020 | Batch 0000/0469 | Cost: 0.0589\n",
      "Epoch: 006/020 | Batch 0050/0469 | Cost: 0.0341\n",
      "Epoch: 006/020 | Batch 0100/0469 | Cost: 0.0205\n",
      "Epoch: 006/020 | Batch 0150/0469 | Cost: 0.0219\n",
      "Epoch: 006/020 | Batch 0200/0469 | Cost: 0.0495\n",
      "Epoch: 006/020 | Batch 0250/0469 | Cost: 0.0344\n",
      "Epoch: 006/020 | Batch 0300/0469 | Cost: 0.0298\n",
      "Epoch: 006/020 | Batch 0350/0469 | Cost: 0.0375\n",
      "Epoch: 006/020 | Batch 0400/0469 | Cost: 0.1366\n",
      "Epoch: 006/020 | Batch 0450/0469 | Cost: 0.0469\n",
      "Epoch: 006/020 | Train: 99.312%\n",
      "Time elapsed: 4.07 min\n",
      "Epoch: 007/020 | Batch 0000/0469 | Cost: 0.0116\n",
      "Epoch: 007/020 | Batch 0050/0469 | Cost: 0.0411\n",
      "Epoch: 007/020 | Batch 0100/0469 | Cost: 0.0115\n",
      "Epoch: 007/020 | Batch 0150/0469 | Cost: 0.0110\n",
      "Epoch: 007/020 | Batch 0200/0469 | Cost: 0.1041\n",
      "Epoch: 007/020 | Batch 0250/0469 | Cost: 0.0172\n",
      "Epoch: 007/020 | Batch 0300/0469 | Cost: 0.0614\n",
      "Epoch: 007/020 | Batch 0350/0469 | Cost: 0.0363\n",
      "Epoch: 007/020 | Batch 0400/0469 | Cost: 0.0366\n",
      "Epoch: 007/020 | Batch 0450/0469 | Cost: 0.0660\n",
      "Epoch: 007/020 | Train: 99.482%\n",
      "Time elapsed: 4.76 min\n",
      "Epoch: 008/020 | Batch 0000/0469 | Cost: 0.0132\n",
      "Epoch: 008/020 | Batch 0050/0469 | Cost: 0.0016\n",
      "Epoch: 008/020 | Batch 0100/0469 | Cost: 0.0226\n",
      "Epoch: 008/020 | Batch 0150/0469 | Cost: 0.0283\n",
      "Epoch: 008/020 | Batch 0200/0469 | Cost: 0.0373\n",
      "Epoch: 008/020 | Batch 0250/0469 | Cost: 0.0584\n",
      "Epoch: 008/020 | Batch 0300/0469 | Cost: 0.0115\n",
      "Epoch: 008/020 | Batch 0350/0469 | Cost: 0.0893\n",
      "Epoch: 008/020 | Batch 0400/0469 | Cost: 0.0368\n",
      "Epoch: 008/020 | Batch 0450/0469 | Cost: 0.0184\n",
      "Epoch: 008/020 | Train: 99.192%\n",
      "Time elapsed: 5.44 min\n",
      "Epoch: 009/020 | Batch 0000/0469 | Cost: 0.0047\n",
      "Epoch: 009/020 | Batch 0050/0469 | Cost: 0.0088\n",
      "Epoch: 009/020 | Batch 0100/0469 | Cost: 0.0021\n",
      "Epoch: 009/020 | Batch 0150/0469 | Cost: 0.0861\n",
      "Epoch: 009/020 | Batch 0200/0469 | Cost: 0.0031\n",
      "Epoch: 009/020 | Batch 0250/0469 | Cost: 0.0761\n",
      "Epoch: 009/020 | Batch 0300/0469 | Cost: 0.0123\n",
      "Epoch: 009/020 | Batch 0350/0469 | Cost: 0.0544\n",
      "Epoch: 009/020 | Batch 0400/0469 | Cost: 0.0174\n",
      "Epoch: 009/020 | Batch 0450/0469 | Cost: 0.0093\n",
      "Epoch: 009/020 | Train: 98.930%\n",
      "Time elapsed: 6.13 min\n",
      "Epoch: 010/020 | Batch 0000/0469 | Cost: 0.0164\n",
      "Epoch: 010/020 | Batch 0050/0469 | Cost: 0.0301\n",
      "Epoch: 010/020 | Batch 0100/0469 | Cost: 0.0198\n",
      "Epoch: 010/020 | Batch 0150/0469 | Cost: 0.0171\n",
      "Epoch: 010/020 | Batch 0200/0469 | Cost: 0.1067\n",
      "Epoch: 010/020 | Batch 0250/0469 | Cost: 0.0099\n",
      "Epoch: 010/020 | Batch 0300/0469 | Cost: 0.0169\n",
      "Epoch: 010/020 | Batch 0350/0469 | Cost: 0.0498\n",
      "Epoch: 010/020 | Batch 0400/0469 | Cost: 0.0394\n",
      "Epoch: 010/020 | Batch 0450/0469 | Cost: 0.0366\n",
      "Epoch: 010/020 | Train: 99.385%\n",
      "Time elapsed: 6.81 min\n",
      "Epoch: 011/020 | Batch 0000/0469 | Cost: 0.0049\n",
      "Epoch: 011/020 | Batch 0050/0469 | Cost: 0.0052\n",
      "Epoch: 011/020 | Batch 0100/0469 | Cost: 0.0019\n",
      "Epoch: 011/020 | Batch 0150/0469 | Cost: 0.0270\n",
      "Epoch: 011/020 | Batch 0200/0469 | Cost: 0.0076\n",
      "Epoch: 011/020 | Batch 0250/0469 | Cost: 0.0091\n",
      "Epoch: 011/020 | Batch 0300/0469 | Cost: 0.0114\n",
      "Epoch: 011/020 | Batch 0350/0469 | Cost: 0.0233\n",
      "Epoch: 011/020 | Batch 0400/0469 | Cost: 0.0443\n",
      "Epoch: 011/020 | Batch 0450/0469 | Cost: 0.0027\n",
      "Epoch: 011/020 | Train: 99.693%\n",
      "Time elapsed: 7.50 min\n",
      "Epoch: 012/020 | Batch 0000/0469 | Cost: 0.0361\n",
      "Epoch: 012/020 | Batch 0050/0469 | Cost: 0.0054\n",
      "Epoch: 012/020 | Batch 0100/0469 | Cost: 0.0485\n",
      "Epoch: 012/020 | Batch 0150/0469 | Cost: 0.0220\n",
      "Epoch: 012/020 | Batch 0200/0469 | Cost: 0.0903\n",
      "Epoch: 012/020 | Batch 0250/0469 | Cost: 0.0144\n",
      "Epoch: 012/020 | Batch 0300/0469 | Cost: 0.0148\n",
      "Epoch: 012/020 | Batch 0350/0469 | Cost: 0.0055\n",
      "Epoch: 012/020 | Batch 0400/0469 | Cost: 0.0012\n",
      "Epoch: 012/020 | Batch 0450/0469 | Cost: 0.0228\n",
      "Epoch: 012/020 | Train: 99.530%\n",
      "Time elapsed: 8.18 min\n",
      "Epoch: 013/020 | Batch 0000/0469 | Cost: 0.0038\n",
      "Epoch: 013/020 | Batch 0050/0469 | Cost: 0.0060\n",
      "Epoch: 013/020 | Batch 0100/0469 | Cost: 0.0206\n",
      "Epoch: 013/020 | Batch 0150/0469 | Cost: 0.0092\n",
      "Epoch: 013/020 | Batch 0200/0469 | Cost: 0.0428\n",
      "Epoch: 013/020 | Batch 0250/0469 | Cost: 0.0627\n",
      "Epoch: 013/020 | Batch 0300/0469 | Cost: 0.0374\n",
      "Epoch: 013/020 | Batch 0350/0469 | Cost: 0.0160\n",
      "Epoch: 013/020 | Batch 0400/0469 | Cost: 0.0013\n",
      "Epoch: 013/020 | Batch 0450/0469 | Cost: 0.0477\n",
      "Epoch: 013/020 | Train: 99.625%\n",
      "Time elapsed: 8.86 min\n",
      "Epoch: 014/020 | Batch 0000/0469 | Cost: 0.0087\n",
      "Epoch: 014/020 | Batch 0050/0469 | Cost: 0.0014\n",
      "Epoch: 014/020 | Batch 0100/0469 | Cost: 0.0032\n",
      "Epoch: 014/020 | Batch 0150/0469 | Cost: 0.0096\n",
      "Epoch: 014/020 | Batch 0200/0469 | Cost: 0.0128\n",
      "Epoch: 014/020 | Batch 0250/0469 | Cost: 0.0131\n",
      "Epoch: 014/020 | Batch 0300/0469 | Cost: 0.0137\n",
      "Epoch: 014/020 | Batch 0350/0469 | Cost: 0.0338\n",
      "Epoch: 014/020 | Batch 0400/0469 | Cost: 0.0393\n",
      "Epoch: 014/020 | Batch 0450/0469 | Cost: 0.0372\n",
      "Epoch: 014/020 | Train: 99.483%\n",
      "Time elapsed: 9.55 min\n",
      "Epoch: 015/020 | Batch 0000/0469 | Cost: 0.0263\n",
      "Epoch: 015/020 | Batch 0050/0469 | Cost: 0.0049\n",
      "Epoch: 015/020 | Batch 0100/0469 | Cost: 0.0198\n",
      "Epoch: 015/020 | Batch 0150/0469 | Cost: 0.0455\n",
      "Epoch: 015/020 | Batch 0200/0469 | Cost: 0.0028\n",
      "Epoch: 015/020 | Batch 0250/0469 | Cost: 0.0069\n",
      "Epoch: 015/020 | Batch 0300/0469 | Cost: 0.0319\n",
      "Epoch: 015/020 | Batch 0350/0469 | Cost: 0.0006\n",
      "Epoch: 015/020 | Batch 0400/0469 | Cost: 0.0022\n",
      "Epoch: 015/020 | Batch 0450/0469 | Cost: 0.0024\n",
      "Epoch: 015/020 | Train: 99.795%\n",
      "Time elapsed: 10.24 min\n",
      "Epoch: 016/020 | Batch 0000/0469 | Cost: 0.0010\n",
      "Epoch: 016/020 | Batch 0050/0469 | Cost: 0.0029\n",
      "Epoch: 016/020 | Batch 0100/0469 | Cost: 0.0031\n",
      "Epoch: 016/020 | Batch 0150/0469 | Cost: 0.0041\n",
      "Epoch: 016/020 | Batch 0200/0469 | Cost: 0.0007\n",
      "Epoch: 016/020 | Batch 0250/0469 | Cost: 0.0130\n",
      "Epoch: 016/020 | Batch 0300/0469 | Cost: 0.0172\n",
      "Epoch: 016/020 | Batch 0350/0469 | Cost: 0.0391\n",
      "Epoch: 016/020 | Batch 0400/0469 | Cost: 0.0171\n",
      "Epoch: 016/020 | Batch 0450/0469 | Cost: 0.0763\n",
      "Epoch: 016/020 | Train: 99.533%\n",
      "Time elapsed: 10.92 min\n",
      "Epoch: 017/020 | Batch 0000/0469 | Cost: 0.0575\n",
      "Epoch: 017/020 | Batch 0050/0469 | Cost: 0.0122\n",
      "Epoch: 017/020 | Batch 0100/0469 | Cost: 0.0356\n",
      "Epoch: 017/020 | Batch 0150/0469 | Cost: 0.0309\n",
      "Epoch: 017/020 | Batch 0200/0469 | Cost: 0.0840\n",
      "Epoch: 017/020 | Batch 0250/0469 | Cost: 0.0178\n",
      "Epoch: 017/020 | Batch 0300/0469 | Cost: 0.0083\n",
      "Epoch: 017/020 | Batch 0350/0469 | Cost: 0.0006\n",
      "Epoch: 017/020 | Batch 0400/0469 | Cost: 0.0114\n",
      "Epoch: 017/020 | Batch 0450/0469 | Cost: 0.0281\n",
      "Epoch: 017/020 | Train: 99.777%\n",
      "Time elapsed: 11.62 min\n",
      "Epoch: 018/020 | Batch 0000/0469 | Cost: 0.0116\n",
      "Epoch: 018/020 | Batch 0050/0469 | Cost: 0.0014\n",
      "Epoch: 018/020 | Batch 0100/0469 | Cost: 0.0149\n",
      "Epoch: 018/020 | Batch 0150/0469 | Cost: 0.0258\n",
      "Epoch: 018/020 | Batch 0200/0469 | Cost: 0.0032\n",
      "Epoch: 018/020 | Batch 0250/0469 | Cost: 0.0026\n",
      "Epoch: 018/020 | Batch 0300/0469 | Cost: 0.0010\n",
      "Epoch: 018/020 | Batch 0350/0469 | Cost: 0.0109\n",
      "Epoch: 018/020 | Batch 0400/0469 | Cost: 0.0003\n",
      "Epoch: 018/020 | Batch 0450/0469 | Cost: 0.0052\n",
      "Epoch: 018/020 | Train: 99.540%\n",
      "Time elapsed: 12.30 min\n",
      "Epoch: 019/020 | Batch 0000/0469 | Cost: 0.0215\n",
      "Epoch: 019/020 | Batch 0050/0469 | Cost: 0.0025\n",
      "Epoch: 019/020 | Batch 0100/0469 | Cost: 0.0884\n",
      "Epoch: 019/020 | Batch 0150/0469 | Cost: 0.0038\n",
      "Epoch: 019/020 | Batch 0200/0469 | Cost: 0.0036\n",
      "Epoch: 019/020 | Batch 0250/0469 | Cost: 0.0061\n",
      "Epoch: 019/020 | Batch 0300/0469 | Cost: 0.0015\n",
      "Epoch: 019/020 | Batch 0350/0469 | Cost: 0.0406\n",
      "Epoch: 019/020 | Batch 0400/0469 | Cost: 0.1211\n",
      "Epoch: 019/020 | Batch 0450/0469 | Cost: 0.0135\n",
      "Epoch: 019/020 | Train: 99.617%\n",
      "Time elapsed: 12.98 min\n",
      "Epoch: 020/020 | Batch 0000/0469 | Cost: 0.0983\n",
      "Epoch: 020/020 | Batch 0050/0469 | Cost: 0.0043\n",
      "Epoch: 020/020 | Batch 0100/0469 | Cost: 0.0492\n",
      "Epoch: 020/020 | Batch 0150/0469 | Cost: 0.0634\n",
      "Epoch: 020/020 | Batch 0200/0469 | Cost: 0.0052\n",
      "Epoch: 020/020 | Batch 0250/0469 | Cost: 0.0082\n",
      "Epoch: 020/020 | Batch 0300/0469 | Cost: 0.0044\n",
      "Epoch: 020/020 | Batch 0350/0469 | Cost: 0.0015\n",
      "Epoch: 020/020 | Batch 0400/0469 | Cost: 0.0153\n",
      "Epoch: 020/020 | Batch 0450/0469 | Cost: 0.0085\n",
      "Epoch: 020/020 | Train: 99.685%\n",
      "Time elapsed: 13.67 min\n",
      "Total Training Time: 13.67 min\n"
     ]
    }
   ],
   "source": [
    "def compute_accuracy(model, data_loader, device):\n",
    "    correct_pred, num_examples = 0, 0\n",
    "    for i, (features, targets) in enumerate(data_loader):\n",
    "            \n",
    "        features = features.to(device)\n",
    "        targets = targets.to(device)\n",
    "\n",
    "        logits, probas = model(features)\n",
    "        _, predicted_labels = torch.max(probas, 1)\n",
    "        num_examples += targets.size(0)\n",
    "        correct_pred += (predicted_labels == targets).sum()\n",
    "    return correct_pred.float()/num_examples * 100\n",
    "    \n",
    "\n",
    "start_time = time.time()\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    \n",
    "    model.train()\n",
    "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
    "        \n",
    "        features = features.to(DEVICE)\n",
    "        targets = targets.to(DEVICE)\n",
    "            \n",
    "        ### FORWARD AND BACK PROP\n",
    "        logits, probas = model(features)\n",
    "        cost = F.cross_entropy(logits, targets)\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        cost.backward()\n",
    "        \n",
    "        ### UPDATE MODEL PARAMETERS\n",
    "        optimizer.step()\n",
    "        \n",
    "        ### LOGGING\n",
    "        if not batch_idx % 50:\n",
    "            print ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f' \n",
    "                   %(epoch+1, NUM_EPOCHS, batch_idx, \n",
    "                     len(train_loader), cost))\n",
    "\n",
    "        \n",
    "\n",
    "    model.eval()\n",
    "    with torch.set_grad_enabled(False): # save memory during inference\n",
    "        print('Epoch: %03d/%03d | Train: %.3f%%' % (\n",
    "              epoch+1, NUM_EPOCHS, \n",
    "              compute_accuracy(model, train_loader, device=DEVICE)))\n",
    "        \n",
    "    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n",
    "    \n",
    "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))"
   ]
  },
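  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the training loop passes the logits, not `probas`, to `F.cross_entropy`, because that function applies the log-softmax internally. As a sketch of what is being minimized, with notation assumed here for illustration ($\\mathbf{z}^{(i)}$ are the logits of example $i$, $y^{(i)}$ is its class label, $N$ is the batch size, and $K$ is the number of classes):\n",
    "\n",
    "$$\\text{softmax}(\\mathbf{z})_k = \\frac{e^{z_k}}{\\sum_{j=1}^{K} e^{z_j}}, \\qquad \\mathcal{L} = -\\frac{1}{N} \\sum_{i=1}^{N} \\log \\text{softmax}\\big(\\mathbf{z}^{(i)}\\big)_{y^{(i)}}$$\n",
    "\n",
    "For $K = 10$ classes, a model that predicts near-uniform probabilities has a loss of about $\\ln 10 \\approx 2.30$, which is of the same order as the first logged cost (2.4814)."
   ]
  },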
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "paaeEQHQj5xC"
   },
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 6514,
     "status": "ok",
     "timestamp": 1524976895054,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "gzQMWKq5j5xE",
    "outputId": "de7dc005-5eeb-4177-9f9f-d9b5d1358db9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy: 98.39%\n"
     ]
    }
   ],
   "source": [
    "with torch.set_grad_enabled(False): # save memory during inference\n",
    "    print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader, device=DEVICE)))"
   ]
  },
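  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above relies on the `compute_accuracy` helper defined earlier in this notebook. As a rough illustration of what such a helper typically does, here is a minimal sketch. It assumes the model returns a `(logits, probas)` tuple, as in the training loop; the name `accuracy_sketch` is hypothetical and this is not necessarily how `compute_accuracy` itself is implemented.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def accuracy_sketch(model, data_loader, device):\n",
    "    # Hypothetical helper; assumes the model returns (logits, probas)\n",
    "    model.eval()\n",
    "    correct, total = 0, 0\n",
    "    with torch.no_grad():  # gradients are not needed for evaluation\n",
    "        for features, targets in data_loader:\n",
    "            features = features.to(device)\n",
    "            targets = targets.to(device)\n",
    "            _, probas = model(features)\n",
    "            predicted = torch.argmax(probas, dim=1)\n",
    "            total += targets.size(0)\n",
    "            correct += (predicted == targets).sum().item()\n",
    "    return 100.0 * correct / total  # accuracy in percent\n",
    "```\n",
    "\n",
    "Counting correct predictions across all batches and dividing once at the end keeps the result exact even when the last batch is smaller than the others."
   ]
  },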
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Fetch a single batch from the test set\n",
    "features, targets = next(iter(test_loader))\n",
    "\n",
    "# Convert the first image from CHW (PyTorch) to HWC (matplotlib) layout\n",
    "hwc_img = features[0].permute(1, 2, 0).numpy()\n",
    "if hwc_img.shape[2] == 1:  # single-channel image: drop the channel axis\n",
    "    hwc_img = np.squeeze(hwc_img, axis=2)\n",
    "plt.imshow(hwc_img, cmap='Greys');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "with torch.no_grad():  # no gradients needed for inference\n",
    "    logits, probas = model(features.to(DEVICE)[0, None])\n",
    "predicted = torch.argmax(probas, dim=1).item()\n",
    "print('Predicted class %d with probability %.2f%%' % (predicted, probas[0][predicted]*100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "numpy       1.15.4\n",
      "pandas      0.23.4\n",
      "torch       1.0.0\n",
      "PIL.Image   5.3.0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%watermark -iv"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "default_view": {},
   "name": "convnet-vgg16.ipynb",
   "provenance": [],
   "version": "0.3.2",
   "views": {}
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": true,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "371px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
